from plyny.presentation.table import Table

from mako.template import Template


def math_wrap(s):
    r"""Return *s* wrapped in ``\ensuremath{...}``.

    A raw string is used so the backslash is literal rather than relying on
    the invalid escape sequence ``\e``.
    """
    return r'\ensuremath{%s}' % s


# Characters that encode() escapes with a leading backslash. Backslash and
# braces are not in this set, so LaTeX markup such as the \ensuremath{...}
# strings built elsewhere in this module passes through encode() unchanged.
escape = '&#%_'


def encode(x):
    """Backslash-escape every character of ``escape`` that occurs in *x*.

    Returns a new string; *x* must support ``str.replace``.
    """
    substitutions = [(special, '\\' + special) for special in escape]
    for needle, replacement in substitutions:
        x = x.replace(needle, replacement)
    return x


class LatexTable(Table):
    r"""A ``Table`` that renders itself as a LaTeX ``tabular`` environment.

    Rendered output can contain ``\multirow`` (grouped headers, merged cells)
    and ``\rowcolor`` (after ``set_row_color``). Those commands come from the
    ``multirow`` and ``colortbl``/``xcolor[table]`` packages, which presumably
    must be loaded by the including document -- TODO confirm.
    """

    # Ready-made math-mode snippets for use as header or cell text.
    MEAN = r'\ensuremath{\bar{\chi}}'
    # NOTE(review): core LaTeX defines no \Chi (capital chi is typeset as a
    # Latin X); this snippet needs a package that provides it -- confirm.
    Chi = r'\ensuremath{\Chi}'
    chi = r'\ensuremath{\chi}'
    Sigma = r'\ensuremath{\Sigma}'
    sigma = r'\ensuremath{\sigma}'

    # Mako body appended after "\begin{tabular}{...}" by make_simple_table.
    # An empty row is emitted as \hline. A backslash ending a template line
    # is Mako's newline-consuming continuation, so each row's cells land on
    # one output line joined by " & " and terminated by "\\".
    _TABULAR_BODY = r"""\hline
    % for i, row in enumerate(rows):
        % if len(row) == 0:
\hline
        % else:
            % if i - row_offset in row_colors:
\rowcolor{${row_colors[i - row_offset]}} \
            % endif
            % for v in row[:-1]:
${v} & \
            % endfor
${row[-1]} \\\

        % endif
    % endfor
\hline
\end{tabular}
"""

    def __init__(self, *header):
        super(LatexTable, self).__init__(*header)
        # When True, render() merges equal cells of vertically adjacent rows
        # into two-row \multirow cells.
        self.multirow = False
        # Data-row index (0 = first row below the header rule) -> colour name
        # emitted via \rowcolor.
        self.row_colors = {}

    def set_row_color(self, row, color):
        """Colour data row *row* (0-based, counted below the header) with *color*."""
        self.row_colors[row] = color

    def _much_less_than(self, value):
        r"""Return *value* prefixed with ``\ll`` in math mode.

        Presumably an override of a ``Table`` formatting hook -- confirm.
        """
        return r'\ensuremath{\ll %s}' % value

    def _scientific(self, index, fmt, value):
        r"""Rewrite the base class's ``<m>e<exp>`` output as ``m \times 10^{exp}``.

        Strings containing neither ``e-`` nor ``e+`` are returned unchanged.
        """
        fmtd = super(LatexTable, self)._scientific(index, fmt, value)
        if 'e-' not in fmtd and 'e+' not in fmtd:
            return fmtd

        toks = fmtd.split('e')
        return r'\ensuremath{%s \times 10^{%s}}' % tuple(toks)

    def render(self):
        r"""Return the table as a LaTeX ``tabular`` string.

        Every cell from ``self._format()`` (header row first) is passed
        through ``encode``. An empty row inserted after the header becomes an
        ``\hline``. When ``self.paired`` is non-empty a super-header of
        grouped columns is prepended (see ``_paired_header``). When
        ``self.multirow`` is set, equal adjacent cells are merged (see
        ``_merge_repeated_cells``).
        """
        # Build fresh lists so whatever structure _format() returns is never
        # modified by the insertions and blanking below.
        rows = [[encode(cell) for cell in row] for row in self._format()]
        header = rows[0]
        # The header row fixes the number of tabular columns.
        cols = len(header)

        rows.insert(1, [])  # empty row -> \hline under the header
        # Index in `rows` of the first data row, so row colours line up.
        row_offset = 2
        if self.paired:
            rows.insert(0, self._paired_header(header))
            row_offset = 3

        if self.multirow:
            self._merge_repeated_cells(rows)

        return self.make_simple_table(rows, False, self.justifications, cols,
                                      row_offset=row_offset)

    def _paired_header(self, header):
        r"""Return the super-header row for the column groups in ``self.paired``.

        ``self.paired`` maps a group's first column index to ``(span, label)``.
        The first column of each group gets a ``\multicolumn`` cell carrying
        the label. The group's other columns get no cell, because the
        multicolumn covers them. Every ungrouped column's name is moved into
        a two-row ``\multirow`` cell and blanked in *header*, which is
        modified in place.
        """
        groups = [(start, start + info[0]) for start, info in self.paired.items()]
        cells = []
        for j, name in enumerate(header):
            group = next((g for g in groups if g[0] <= j < g[1]), None)
            if group is None:
                cells.append(r'\multirow{2}{*}{%s}' % name)
                header[j] = ''
            elif j == group[0]:
                # \multicolumn replaces the column's border spec, so only the
                # leftmost column has to restore the outer left rule.
                left_rule = '|' if j == 0 else ''
                span, label = self.paired[j][0], self.paired[j][1]
                cells.append(r'\multicolumn{%d}{%sc|}{%s}' % (span, left_rule, label))
        return cells

    @staticmethod
    def _merge_repeated_cells(rows):
        r"""Merge equal cells of vertically adjacent rows, in place.

        Rows are scanned from ``rows[1]`` onwards. When a row shares a value
        in some column with the row above, the upper cell becomes a two-row
        ``\multirow`` and the lower one is blanked. A row that was merged
        into is not compared with the row after it, so every merge spans
        exactly two rows.
        """
        merges = {}  # index of upper row -> columns merged into the row below
        previous = None
        for index in range(1, len(rows)):
            row = rows[index]
            same = []
            if previous is not None:
                same = [j for j in range(min(len(row), len(previous)))
                        if row[j] == previous[j]]
            if same:
                merges[index - 1] = same
                previous = None
            else:
                previous = row

        for upper, columns in merges.items():
            for j in columns:
                rows[upper][j] = r'\multirow{2}{*}{%s}' % rows[upper][j]
                rows[upper + 1][j] = ' '

    @staticmethod
    def _column_letter(justification):
        """Map 0/1/2 to the tabular letters l/c/r; pass any other value through."""
        for code, letter in ((0, 'l'), (1, 'c'), (2, 'r')):
            if justification == code:
                return letter
        return justification

    def make_simple_table(self, rows, distinct_final_row=False,
            justifications=None, cols=None, row_offset=2):
        r"""Render *rows* as a fully ruled LaTeX ``tabular`` environment.

        :param rows: list of rows (lists of cell strings); an empty row is
            emitted as ``\hline``.
        :param distinct_final_row: forwarded to the template, which does not
            use it; kept for API compatibility.
        :param justifications: per-column alignment, 0/1/2 meaning l/c/r
            (any other value is used verbatim as the column spec). Defaults
            to all centred. The given sequence is not modified.
        :param cols: number of columns; defaults to ``len(rows[0])``.
        :param row_offset: index in *rows* of the first data row, so that
            ``self.row_colors`` keys count data rows from 0.
        :raises IndexError: if *justifications* has fewer than *cols* entries.
        """
        # assume that all rows have same number of columns
        if not cols:
            cols = len(rows[0])
        if justifications is None:
            justifications = [1] * cols

        spec = '|'.join(self._column_letter(justifications[i]) for i in range(cols))
        source = r'\begin{tabular}{|%s|}' % spec + self._TABULAR_BODY
        return Template(source).render(rows=rows, row_colors=self.row_colors,
                                       row_offset=row_offset,
                                       distinct_final_row=distinct_final_row)

if __name__ == '__main__':
    # Small demo: column "b" repeats the value 2 in both rows, which
    # multirow mode should merge; the first data row is also coloured.
    t = LatexTable('a', 'b')
    t.add(1, 2)
    t.add(3, 2)
    t.multirow = True
    t.set_row_color(0, 'light-gray')

    print(t.render())
